# ------------------------------------------------------------------------------
# Plots physician test rates and algorithm test choices by percentile of
# stent-hat-1
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 08_plot_outcomes.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(testit) # assert()

# Scratch directory for this step's intermediate inputs and saved figures.
temp <- here("code", "08_train_split_model", "temp")

# Project helper modules: `u` = general utilities, `a` = plotting aesthetics.
u <- modules::use(here("lib", "util.R"))
a <- modules::use(here("lib", "aesthetics.R"))

# Make the Optima font available for the plots below.
a$get_font("Optima", here("lib", "optima.ttf"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Per-encounter cost and DALY columns, keyed by ed_enc_id.
daly_cols <- c("ed_enc_id", "cost", "daly_200", "daly_250", "daly_300")
daly_cost <- paths$analysis$daly_cost %>%
  readRDS() %>%
  select(all_of(daly_cols))

# Scored test cohort with excluded rows removed, then joined to costs/DALYs.
cohort <- file.path(temp, "scored_test_cohort.rds") %>%
  readRDS() %>%
  filter(!exclude)
cohort <- cohort %>%
  u$safe_left_join(daly_cost)

# Cost-eff Cutoff --------------------------------------------------------------
# Estimate the Ensemble 2 predicted-risk value at which cost_per_daly_250
# crosses `cost_eff_threshold`. The estimate linearly interpolates between
# two adjacent risk tiles of the tested population.
message("Getting costeff cutoff for model 2...")
cost_eff_threshold <- 150000

# One row per tile_005_ensemble__2__tested value, restricted to rows where
# test_010_day is TRUE. Rows come out sorted by tile.
cost_eff_tbl <- cohort %>%
  filter(test_010_day) %>%
  group_by(tile_005_ensemble__2__tested) %>%
  summarize(
    n = n(),
    mean_risk = mean(p__ensemble__2),
    avg_cost = mean(cost, na.rm = TRUE),
    avg_daly_250 = mean(daly_250, na.rm = TRUE),
    cost_per_daly_250 = sum(cost, na.rm = TRUE) / sum(daly_250, na.rm = TRUE)
  ) %>%
  ungroup()

print(cost_eff_tbl)

# The interpolation uses rows 3 and 4 of the table. It is only meaningful if
# the threshold lies between those rows' cost_per_daly_250 values. Otherwise
# the line would be extrapolated and the cutoff would be silently wrong, so
# fail loudly instead.
bracket <- cost_eff_tbl[3:4, ]
bracket_range <- range(bracket$cost_per_daly_250)
assert(
  "150K threshold lies between cost_per_daly_250 of tiles in rows 3 and 4",
  nrow(bracket) == 2,
  cost_eff_threshold >= bracket_range[1],
  cost_eff_threshold <= bracket_range[2]
)

# A two-point lm is exactly the line through the two tiles; solve it for the
# risk at which cost_per_daly_250 equals the threshold.
fit <- lm(cost_per_daly_250 ~ mean_risk, data = bracket)
risk_cutoff <- unname(
  (cost_eff_threshold - coef(fit)["(Intercept)"]) / coef(fit)["mean_risk"]
)
message("150K risk threshold = ", risk_cutoff)

# Get algorithmically chosen tests ---------------------------------------------
# Two simulated testing policies based on Ensemble 2 predicted risk:
#   sim_test_010_day: test the num_tests highest-risk patients, where
#                     num_tests is the number physicians actually tested.
#   sim_test_costeff: test everyone at or above the cost-effectiveness
#                     risk_cutoff computed in the previous section.
message("Simulating tests with ensemble_2...")
num_tests <- sum(cohort$test_010_day)
cohort <- cohort %>%
  # ties.method = "first" breaks ties by row order, so exactly num_tests
  # patients are selected. The default ("average") can select more or fewer
  # when predictions are tied at the boundary.
  mutate(risk_rank = rank(-p__ensemble__2, ties.method = "first")) %>%
  mutate(sim_test_010_day = risk_rank <= num_tests) %>%
  mutate(sim_test_costeff = p__ensemble__2 >= risk_cutoff)

# Plot outcome rates -----------------------------------------------------------
message("Getting outcome rates...")

# Summarize a logical test indicator by Ensemble 1 risk percentile.
#   df:       the cohort data frame (must contain tile_100_ensemble__1).
#   test_var: bare column name of the logical indicator to summarize.
#   label:    string stored in the `grouping` column (legend label).
#   denom:    denominator for cum_outcome_rate.
# Returns one row per tile_100_ensemble__1 value, sorted by tile, with
# outcome_rate (share), outcome_sum (count), n, std.deviation, std.error,
# grouping, and cum_outcome_rate = cumsum(outcome_sum) / denom.
summarize_test_rate <- function(df, test_var, label, denom) {
  df %>%
    group_by(tile_100_ensemble__1) %>%
    summarize(
      outcome_rate = mean({{ test_var }}),
      outcome_sum = sum({{ test_var }}),
      n = n(),
      std.deviation = sd({{ test_var }}),
      std.error = std.deviation / sqrt(n)
    ) %>%
    mutate(grouping = .env$label) %>%
    mutate(cum_outcome_rate = cumsum(outcome_sum) / .env$denom)
}

test_rates <- summarize_test_rate(
  cohort, test_010_day, "Physician Tests", num_tests
)
sim_test_rates <- summarize_test_rate(
  cohort, sim_test_010_day, "Algorithm Tests (Ensemble 2)", num_tests
)

# Long format for plotting: both strategies stacked, x axis renamed to x_var.
outcome_df <- bind_rows(test_rates, sim_test_rates) %>%
  rename(x_var = tile_100_ensemble__1)

message("Plotting...")
message("Tile plot...")

# Physician vs. algorithm test rate per Ensemble 1 risk percentile.
decision_title <- "Testing Decision"
gg <- a$tile_plot(outcome_df) +
  labs(
    x = "Percentile of Predicted Risk (Ensemble 1)",
    y = "Test Rate",
    color = decision_title,
    fill = decision_title
  ) +
  scale_x_continuous(breaks = seq(0, 100, by = 10))

message("150K cost-eff cutoff plot...")

# Share of patients the cost-effectiveness policy (sim_test_costeff) would
# test, per Ensemble 1 risk percentile.
by_percentile <- group_by(cohort, tile_100_ensemble__1)
sim_test_rates_150K <- by_percentile %>%
  summarize(
    outcome_rate = mean(sim_test_costeff),
    outcome_sum = sum(sim_test_costeff),
    n = n(),
    std.deviation = sd(sim_test_costeff),
    std.error = std.deviation / sqrt(n)
  ) %>%
  mutate(grouping = "Algorithm Tests (Ensemble 2)") %>%
  rename(x_var = tile_100_ensemble__1)

gg_150K <- a$tile_plot(sim_test_rates_150K) +
  labs(
    x = "Percentile of Predicted Risk (Ensemble 1)",
    y = "Test Rate"
  ) +
  scale_x_continuous(breaks = seq(0, 100, by = 10))

message("CDF plot...")
# Cumulative share of tests by Ensemble 1 risk percentile, one line per
# testing strategy. cum_outcome_rate is cumsum(outcome_sum) / num_tests, a
# 0-1 fraction, so the y axis is labelled in percent to match its title.
gg_cdf <- ggplot(
  outcome_df, aes(x = x_var, y = cum_outcome_rate, color = grouping)
) +
  geom_line() +
  labs(
    x = "Percentile of Predicted Risk (Ensemble 1)",
    y = "Cumulative Percent of Tests", color = "Testing Decision"
  ) +
  scale_x_continuous(breaks = seq(0, 100, 10)) +
  scale_y_continuous(labels = function(x) paste0(x * 100, "%")) +
  theme_bw() +
  theme(
    legend.position = "bottom",
    text = element_text(family = "Optima", size = 40)
  ) +
  scale_color_manual(values = a$disc_palette)

# Save -------------------------------------------------------------------------
message("Saving...")

# Write one figure to `filename` on the shared 10 x 7 inch canvas.
# `units` is spelled out rather than relying on partial matching of `unit`.
save_plot <- function(plot, filename) {
  invisible(
    ggsave(
      plot = plot, filename = filename, width = 10, height = 7, units = "in"
    )
  )
}

save_fp <- file.path(temp, "test_rate_plot.png")
save_plot(gg, save_fp)

save_fp_150K <- file.path(temp, "sim_test_150K.png")
save_plot(gg_150K, save_fp_150K)

cdf_save_fp <- file.path(temp, "test_rate_cdf.png")
save_plot(gg_cdf, cdf_save_fp)

message("Done.")
